Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX: Removed duplicate convolution for DoRA #2153

Open
wants to merge 6 commits into
base: main
Choose a base branch
from

Conversation

gslama12
Copy link

@gslama12 gslama12 commented Oct 16, 2024

This pull request fixes two problems:

1. Duplicate Convolution in DoRA implementation for ConvNd Layers:
Since the base layer convolution is already computed in layer.py, we don't need to compute it again in the dora.py. Computing it again doubles the FLOPs consumption during the forward pass resulting in significantly higher FLOPs overall. We can pass the result from the base layer computed in layer.py to the forward pass of the _DoraConvNdLayer in dora.py and save computational resources.

2. Bugfix for DoRA regarding Convolutional Layers using the Groups Argument:
CNNs that for example use depthwise separable convolutional layers result in an error when applying DoRA. Adjusting the dimension of the conv_layer in layer.py fixes this issue.

@BenjaminBossan
Copy link
Member

Thanks for this PR. We're aware of this potential inefficiency but I think it's not as easy as re-using the base result. The reasoning is explained here. Back then, we were only dealing with linear layers but I'm certain that the same logic applies to convolutional layers.

The good news is that this optimization is indeed possible if dropout is set to 0 or if we're in eval mode, see #2122. LMK what you think.

@gslama12
Copy link
Author

gslama12 commented Oct 17, 2024

Thanks for the clarification! Would it be possible to apply the dropout for DoRA similar to how LoRA handles it, i.e. in lora_B(lora_A(x)) during the forward pass? It seems like LoRA also uses the pre-dropout information (result) here and does not re-compute the convolution: result = result + lora_B(lora_A(dropout(x))) * scaling (ln 1120 in layer.py)?

Also, what do you think of the fix for convolutional layers using the "groups" argument?

@BenjaminBossan
Copy link
Member

Would it be possible to apply the dropout for DoRA similar to how LoRA handles it, i.e. in lora_B(lora_A(x)) during the forward pass? It seems like LoRA also uses the pre-dropout information (result) here and does not re-compute the convolution

Could you clarify what you mean here, maybe with a small code example? Note that we have to ensure that the implementation sticks with the specification of the original paper.

When we have no dropout, though, we should be able to make the same optimization as in #2122 though.

Also, what do you think of the fix for convolutional layers using the "groups" argument?

I wasn't aware of the groups argument in Conv2d. Quite possibly your solution is correct. Could you give a small example, so that we can build a unit test based on that? Also, could you explain why we only need it for DoRA?

@gslama12
Copy link
Author

gslama12 commented Oct 19, 2024

So the reasoning behind why the DoRA optimization is not possible when we use lora_dropout != 0 was, that the magnitude vector in the code snippet below will receive with pre-dropout information in form of the base layer result.

x = dropout(x)
result = result + self.lora_magnitude_vector[active_adapter](
                        x,
                        lora_A=lora_A,
                        lora_B=lora_B,
                        scaling=scaling,
                        base_layer=self.get_base_layer(),
                        base_layer_result=result
                    )

After looking at the DoRA paper, I can't figure out why this is an issue. If we see DoRA as a magnitude * direction, where LoRA is applied to the direction component, shouldn't we be able to apply LoRA to the direction component using lora_B(lora_A(dropout(x)))? Basically computing the DoRA result as: result_dora = (mag_norm_scale - 1) * base_layer_result + mag_norm_scale * lora_B(lora_A(dropout(x))) * scaling in dora.py.

Based on the paper, why would we need to compute the full-rank convolution again with a dropout(x) as the input (dora.py, ln 161)?

@BenjaminBossan
Copy link
Member

Note that when we have LoRA+DoRA+dropout, we ensure that dropout is consistently applied to the LoRA part and the "base_result" part. If we use the result from the base layer directly (i.e. base_layer_result in your code), the LoRA-dropout is not applied to it, therefore, the result differs.

@gslama12
Copy link
Author

gslama12 commented Oct 21, 2024

I think I understand your point. But if we look at LoRA (e.g. ln 1120 in layer.py) we see, that we also don't apply the lora_dropout to the result here:

result = self.base_layer(x, *args, **kwargs)
.
.
.
if not self.use_dora[active_adapter]:
    result = result + lora_B(lora_A(dropout(x))) * scaling

So my question is wether the "base_result" part even needs the dropout for DoRA. And if we do need the dropout in the "base_result" part, why do we not need it for LoRA?

@BenjaminBossan
Copy link
Member

Exactly. The result from the base model does not contain any LoRA-dropout applied to the x. However, for the DoRA-part, we need to ensure that the same dropout is applied to the x of base_layer_result as for the LoRA part. In the existing code, we apply dropout to x, then pass x to the DoRA layer, where x is passed to the convolution operation. Therefore, this requirement is met.

In your proposed code, we would instead use the result from the base layer which does not include x with dropout. Therefore, this is a different output and the DoRA calculation would no longer correspond to what is described in the paper.

Only if there is no dropout can we re-use the base result in the way you proposed, as is shown in #2122. (In addition, we also need to take care of potentially existing bias terms, but I left that out to simplify the discussion.)

You can also create a LoRA layer with DoRA and check that the outputs differ when dropout is applied between the old code and the suggested change (fix the seed to ensure that there is no randomness involved).

@gslama12
Copy link
Author

Yes, I understand the reasoning and that my suggestion would produce a different output. I just could not find the reasoning for why the dropout needs to be applied like this in the DoRA paper. But I'll assume you're right.

The reason why I am questioning this is because we do not seem to use the dropout in the same way for LoRA. So my question is:

Why do we not apply the dropout to the result for the base layer for LoRA in the code example?

result = self.base_layer(x, *args, **kwargs) 
.
.
.
if not self.use_dora[active_adapter]:
    result = result + lora_B(lora_A(dropout(x))) * scaling   # why can we use the base_layer result without dropout here?

@BenjaminBossan
Copy link
Member

Okay, I understand now. Our calculation for DoRA is a bit more complicated than if we simply followed the equation from the paper. The reason is that we first calculate the base result from the base model. Thus we need to correct it later, which is why we have the (mag_norm_scale - 1) term, which cannot be directly seen in the paper. I agree that it's not immediately obvious why for that part, we require the dropout being applied.

Trying to fit this into an equation, this is what I get (note that I omitted the scale for simplicity):

image

If we did not have the dropout in the base result part, then the first 2 terms in the final equation, W x and W drop(x) would cancel out, which looks correct. However, it would also mean we would not be able to simplify the last term, (W + B A) drop(x), which looks incorrect.

Not sure if @nbasyl would have time to check this.

@gslama12
Copy link
Author

gslama12 commented Oct 23, 2024

Right! And if we look at the implementation of LoRA (ln 1121 in layer.py), we see that there we also can't simplify the term W x + b + (B A) drop(x) to (W B A) drop(x) + b.

Since the dropout we are talking about here is actually a lora_dropout, I feel like it would be more intuitive, if it was only applied to the LoRA part of the DoRA implementation (i.e. only to the term m/W_n B A drop (x)). Also it would seem strange to have this non-zero difference W x - W drop(x) in the final result equation if we set lora_dropout != 0.

@gslama12
Copy link
Author

Any updates on this?

@BenjaminBossan
Copy link
Member

Not sure if Shih-Yang currently has time to look into this. In the meantime, how about opening a separate PR for the groups issue. We can also open a separate PR along the lines of #2122, since this optimization should definitely work for dropout=0 or training=False.

@gslama12 gslama12 changed the title FIX: Removed duplicate convolution for DoRA & fixed error for ConvNd layers using "groups" FIX: Removed duplicate convolution for DoRA" Nov 5, 2024
@gslama12 gslama12 changed the title FIX: Removed duplicate convolution for DoRA" FIX: Removed duplicate convolution for DoRA Nov 5, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants